{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# {func}`tvm.te.tag.tag_scope`\n",
    "\n",
    "{func}`~tvm.te.tag.tag_scope` 函数，它接受字符串参数 `tag` 并返回类型为 {class}`~tvm.te.tag.TagScope` 的对象。{class}`~tvm.te.tag.TagScope` 类用于创建具有特定标签的算子的作用域，使它们能够轻松地被识别和管理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import te"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以 `with` 管理器的形式构建："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "<span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(var_A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, var_B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle, var_compute: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>handle):\n",
       "    T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>})\n",
       "    n, l <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32(), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
       "    A <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(var_A, (n, l))\n",
       "    m <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32()\n",
       "    B <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(var_B, (m, l))\n",
       "    compute <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>match_buffer(var_compute, (n, m))\n",
       "    <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "    <span style=\"color: #008000; font-weight: bold\">for</span> i, j, k <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(n, m, l):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;compute&quot;</span>):\n",
       "            v_i, v_j, v_k <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i, j, k])\n",
       "            T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[v_i, v_k], B[v_j, v_k])\n",
       "            T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(compute[v_i, v_j])\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>init():\n",
       "                compute[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
       "            compute[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">=</span> compute[v_i, v_j] <span style=\"color: #AA22FF; font-weight: bold\">+</span> A[v_i, v_k] <span style=\"color: #AA22FF; font-weight: bold\">*</span> B[v_j, v_k]\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "n = te.var('n')\n",
    "m = te.var('m')\n",
    "l = te.var('l')\n",
    "A = te.placeholder((n, l), name='A')\n",
    "B = te.placeholder((m, l), name='B')\n",
    "k = te.reduce_axis((0, l), name='k')\n",
    "with tvm.te.tag_scope(tag='matmul'):\n",
    "    C = te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k))\n",
    "te.create_prim_func([A, B, C]).show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "或者使用装饰器的方式构建："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'tvm.te.tensor.Tensor'>\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "<span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">2</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), compute: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">2</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "    T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>})\n",
       "    <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "    <span style=\"color: #008000; font-weight: bold\">for</span> i0 <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">2</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;compute&quot;</span>):\n",
       "            v_i0 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>spatial(<span style=\"color: #008000\">2</span>, i0)\n",
       "            T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(data[v_i0])\n",
       "            T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(compute[v_i0])\n",
       "            compute[v_i0] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Select(data[v_i0] <span style=\"color: #AA22FF; font-weight: bold\">&lt;</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>), data[v_i0])\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tvm.topi import tag\n",
    "@tvm.te.tag_scope(tag=tag.ELEMWISE)\n",
    "def compute_relu(data):\n",
    "    \"\"\"计算 data relu 值\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : tvm.te.Tensor\n",
    "        Input argument.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    y : tvm.te.Tensor\n",
    "        The result.\n",
    "    \"\"\"\n",
    "    print(type(data))\n",
    "    return te.compute(data.shape, lambda *i: tvm.tir.Select(data(*i) < 0, 0.0, data(*i)))\n",
    "\n",
    "data = te.placeholder(shape=(2,), dtype=\"float32\", name=\"data\")\n",
    "out = compute_relu(data)\n",
    "te.create_prim_func([data, out]).show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tag_scope conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.te.tag_scope(tag=\"conv\")\n",
    "def compute_conv(data, weight):\n",
    "    N, IC, H, W = data.shape\n",
    "    OC, IC, KH, KW = weight.shape\n",
    "    OH = H - KH + 1\n",
    "    OW = W - KW + 1\n",
    "\n",
    "    ic = te.reduce_axis((0, IC), name=\"ic\")\n",
    "    dh = te.reduce_axis((0, KH), name=\"dh\")\n",
    "    dw = te.reduce_axis((0, KW), name=\"dw\")\n",
    "\n",
    "    return te.compute(\n",
    "        (N, OC, OH, OW),\n",
    "        lambda i, oc, h, w: te.sum(\n",
    "            data[i, ic, h + dh, w + dw] * weight[oc, ic, dh, dw], axis=[ic, dh, dw]\n",
    "        ),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = te.size_var(\"n\")\n",
    "m = te.size_var(\"m\")\n",
    "l = te.size_var(\"l\")\n",
    "\n",
    "A = te.placeholder((n, l), name=\"A\")\n",
    "B = te.placeholder((m, l), name=\"B\")\n",
    "with tvm.te.tag_scope(tag=\"gemm\"):\n",
    "    k = te.reduce_axis((0, l), name=\"k\")\n",
    "    C = te.compute(\n",
    "        (n, m),\n",
    "        lambda i, j: te.sum(A[i, k] * B[j, k], axis=k),\n",
    "        attrs={\"hello\": 1, \"arr\": [10, 12]},\n",
    "    )\n",
    "\n",
    "assert C.op.tag == \"gemm\"\n",
    "assert \"hello\" in C.op.attrs\n",
    "assert \"xx\" not in C.op.attrs\n",
    "assert C.op.attrs[\"hello\"] == 1\n",
    "CC = tvm.ir.load_json(tvm.ir.save_json(C))\n",
    "assert CC.op.attrs[\"hello\"] == 1\n",
    "assert len(CC.op.attrs[\"arr\"]) == 2\n",
    "assert CC.op.attrs[\"arr\"][0] == 10\n",
    "assert CC.op.attrs[\"arr\"][1] == 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = te.size_var(\"n\")\n",
    "c = te.size_var(\"c\")\n",
    "h = te.size_var(\"h\")\n",
    "w = te.size_var(\"w\")\n",
    "kh = te.size_var(\"kh\")\n",
    "kw = te.size_var(\"kw\")\n",
    "\n",
    "A = te.placeholder((n, c, h, w), name=\"A\")\n",
    "B = te.placeholder((c, c, kh, kw), name=\"B\")\n",
    "C = compute_conv(A, B)\n",
    "assert C.op.tag == \"conv\"\n",
    "assert len(C.op.attrs) == 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "嵌套："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm/python/tvm/te/tag.py:50: UserWarning: Tag 'conv' declared via TagScope was not used.\n",
      "  warnings.warn(f\"Tag '{self.tag}' declared via TagScope was not used.\")\n"
     ]
    }
   ],
   "source": [
    "n = te.size_var(\"n\")\n",
    "c = te.size_var(\"c\")\n",
    "h = te.size_var(\"h\")\n",
    "w = te.size_var(\"w\")\n",
    "kh = te.size_var(\"kh\")\n",
    "kw = te.size_var(\"kw\")\n",
    "\n",
    "A = te.placeholder((n, c, h, w), name=\"A\")\n",
    "B = te.placeholder((c, c, kh, kw), name=\"B\")\n",
    "try:\n",
    "    with te.tag_scope(tag=\"conv\"):\n",
    "        C = compute_conv(A, B)\n",
    "    assert False\n",
    "except ValueError:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
